{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Part 1. Add speech ID and filter text\n",
    "### Read data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json \n",
    "# Annotation file to load; the code below indexes the result like a list of samples.\n",
    "sft_data_path = 'mgm_instruction_clear.json'\n",
    "# A context manager closes the handle promptly, and pinning UTF-8 keeps decoding\n",
    "# independent of the platform's default locale encoding.\n",
    "with open(sft_data_path, 'r', encoding='utf-8') as json_file:\n",
    "    sft_data = json.load(json_file)\n",
    "print(len(sft_data))\n",
    "print(sft_data[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Add a unique speech ID"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "# Give every sample a sequential id of the form lyra_mm_<index>. Each sample is\n",
    "# deep-copied first, so the entries of `sft_data` themselves are left untouched.\n",
    "new_sft_data_list = [\n",
    "    {**copy.deepcopy(sample), \"speech_id\": \"lyra_mm_{}\".format(idx)}\n",
    "    for idx, sample in enumerate(sft_data)\n",
    "]\n",
    "\n",
    "print(len(new_sft_data_list))\n",
    "print(new_sft_data_list[0])\n",
    "\n",
    "# Write the tagged samples out as indented JSON, keeping non-ASCII characters as-is.\n",
    "save_path = 'mgm_instruction_clear_with_speechid.json'\n",
    "with open(save_path, mode='w', encoding='utf-8') as json_file:\n",
    "    json.dump(new_sft_data_list, json_file, indent=4, ensure_ascii=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Extract text to convert"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "def is_en_str(string):\n",
    "    pattern = re.compile(r'[a-zA-Z0-9\\s.,!?\\'\"()-]*')\n",
    "    if pattern.fullmatch(string):\n",
    "        return True\n",
    "    else:\n",
    "        return False\n",
    "\n",
    "def replace_choice(string):\n",
    "    return string.replace(\"\\nA.\", \"\\nOption A is \").replace(\"\\nB.\", \"\\nOption B is \").replace(\"\\nC.\", \"\\nOption C is \").replace(\"\\nD.\", \"\\nOption D is \")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the list of questions that will be converted to speech.\n",
    "# Iterate over `new_sft_data_list` (the samples tagged with `speech_id`) rather than\n",
    "# the raw `sft_data`: tagging was done on deep copies, so the raw items only have a\n",
    "# `speech_id` if the source file already provided one.\n",
    "new_text_list = []\n",
    "one_round = 0  # number of kept samples whose conversation has exactly two turns\n",
    "for item in new_sft_data_list:\n",
    "    conversations = item['conversations']\n",
    "    # The question is the second-to-last turn. Skip samples that are too short to\n",
    "    # have one, or whose second-to-last turn does not come from 'human'.\n",
    "    if len(conversations) < 2 or conversations[-2]['from'] != 'human':\n",
    "        continue\n",
    "    if len(conversations) == 2:\n",
    "        one_round += 1\n",
    "\n",
    "    # Strip image placeholders, then spell out the multiple-choice markers.\n",
    "    question = conversations[-2]['value'].replace('<image>\\n', '').replace('\\n<image>', '').replace('<image>', '')\n",
    "    question = replace_choice(question)\n",
    "    new_text_list.append({\n",
    "        'speech_id': item['speech_id'],\n",
    "        'question': question,\n",
    "        'conv_length': len(conversations),\n",
    "    })\n",
    "\n",
    "print(one_round)\n",
    "print(len(new_text_list))\n",
    "print(new_text_list[0])\n",
    "\n",
    "# json to file\n",
    "save_path = 'sft_ques_total_with_modified_options.json'\n",
    "with open(save_path, 'w', encoding='utf-8') as json_file:\n",
    "    json.dump(new_text_list, json_file, ensure_ascii=False, indent=4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Special processing for OCR data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect the questions that contain the Reference OCR marker and keep only the\n",
    "# text in front of that marker; each kept entry is a copy of the original one.\n",
    "ocr_ques_list = []\n",
    "for entry in new_text_list:\n",
    "    head, marker, _tail = entry[\"question\"].partition(\"\\nReference OCR\")\n",
    "    if not marker:\n",
    "        continue\n",
    "    trimmed = copy.deepcopy(entry)\n",
    "    trimmed[\"question\"] = head\n",
    "    ocr_ques_list.append(trimmed)\n",
    "print(len(ocr_ques_list), ocr_ques_list[:2])\n",
    "\n",
    "# Write the trimmed OCR questions out as indented JSON.\n",
    "save_path = 'sft_ques_textOCR.json'\n",
    "with open(save_path, mode='w', encoding='utf-8') as json_file:\n",
    "    json.dump(ocr_ques_list, json_file, indent=4, ensure_ascii=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Part 2. Reorganize SFT annotation data\n",
    "### Tag all successfully generated speeches"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Directory whose entries are listed to find the available speech ids.\n",
    "mgm_sft_speech_path = \"Lyra_MM\"\n",
    "# Key: entry name up to its first dot; value: constant 1 (used as a membership set).\n",
    "mgm_sft_speech_ids = {\n",
    "    file_name.split('.')[0]: 1\n",
    "    for file_name in os.listdir(mgm_sft_speech_path)\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "short_path = \"Lyra_MM\"\n",
    "sft_data_path = 'mgm_instruction_clear_with_speechid.json'\n",
    "# This file is written with encoding='utf-8' and ensure_ascii=False, so read it back\n",
    "# with the same encoding; the context manager also closes the handle.\n",
    "with open(sft_data_path, 'r', encoding='utf-8') as json_file:\n",
    "    sft_data = json.load(json_file)\n",
    "\n",
    "ocr_marker = \"\\nReference OCR\"\n",
    "new_text_list = []\n",
    "for item in sft_data:\n",
    "    new_item = copy.deepcopy(item)\n",
    "    conversations = new_item['conversations']\n",
    "    # Only samples whose speech id was found on disk (and that have a second-to-last\n",
    "    # turn to rewrite) are modified; every other sample is passed through unchanged.\n",
    "    if new_item['speech_id'] in mgm_sft_speech_ids and len(conversations) >= 2:\n",
    "        question_turn = conversations[-2]\n",
    "        text_question = question_turn['value']\n",
    "        # Two-turn samples that carry an 'image' key keep the image placeholder in\n",
    "        # front of the speech placeholder; all other samples get the speech one only.\n",
    "        if len(conversations) == 2 and 'image' in new_item:\n",
    "            question_turn['value'] = \"<image>\\n<speech>\"\n",
    "        else:\n",
    "            question_turn['value'] = \"<speech>\"\n",
    "\n",
    "        # OCR data: the reference section stays as text after the placeholder and only\n",
    "        # the part before the marker is kept as the question. Splitting on the first\n",
    "        # marker only means a repeated marker cannot truncate the reference section.\n",
    "        text_question, marker, ocr_tail = text_question.partition(ocr_marker)\n",
    "        if marker:\n",
    "            question_turn['value'] += marker + ocr_tail\n",
    "\n",
    "        text_question = replace_choice(text_question)\n",
    "\n",
    "        # Questions shorter than 10 characters get no 'speech_asr' field.\n",
    "        if len(text_question) >= 10:\n",
    "            new_item['speech_asr'] = text_question.replace('<image>\\n', '').replace('\\n<image>', '').replace('<image>', '')\n",
    "        new_item['speech'] = os.path.join(short_path, \"{}.mp3\".format(new_item['speech_id']))\n",
    "    new_text_list.append(new_item)\n",
    "\n",
    "print(len(new_text_list))\n",
    "print(new_text_list[0])\n",
    "\n",
    "# json to file\n",
    "save_path = 'mgm_instruction_clear_with_speech.json'\n",
    "with open(save_path, 'w', encoding='utf-8') as json_file:\n",
    "    json.dump(new_text_list, json_file, ensure_ascii=False, indent=4)"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
